#!/usr/bin/env python


# ---- IMPORT MODULES

import numpy as np

import re
from Hive import Hive
from Hive import Utilities
import data_utils
import pandas as pd
import bnlearn as bn
from colorama import Fore
import fire

from Hive.DAG_update import DAG_Utils, DG_Utils, DataConversion
from pgmpy.estimators import BDeuScore, K2Score, BicScore
from pgmpy.models import BayesianNetwork
from data_utils import best_vector
import json
import os
import time
from gen_plot import loss_plot_simple
import argparse
import random

# Seed Python's built-in ``random`` module so any use of it in this process is
# reproducible across runs.
# NOTE(review): numpy's global RNG is not seeded here — confirm whether the
# Hive optimizer relies on it or only on its own ``seed`` argument.
random.seed(0)


class Evaluation:
    """Score candidate network structures against one DREAM4 data set.

    On construction it loads the time-series data of network ``idx`` into a
    DataFrame, reads the matching gold-standard network as a vector, and
    builds a pgmpy structure scorer (K2, BIC or BDeu) over the data. Candidate
    solution vectors are decoded to edge lists with
    ``DataConversion.general_vector2edges`` and scored with that scorer.
    """

    # Scorer names accepted by the ``evaluator`` argument.
    _EVALUATORS = ("k2", "bic", "bde")
    # Network sizes for which a data reader is wired up below.
    _SIZES = (10, 100)

    def __init__(self, idx, mode="D3", evaluator="bic", size=10):
        """Load the data and gold standard for one network and build a scorer.

        Args:
            idx: DREAM4 network index (used for the data and gold-standard files).
            mode: data/vector layout mode forwarded to ``data_utils`` and
                ``DataConversion`` (e.g. "D1", "D3").
            evaluator: structure score to use: "k2", "bic" or "bde".
            size: network size, 10 or 100.

        Raises:
            ValueError: if ``size`` or ``evaluator`` is not supported.
        """
        # Validate up front: an unsupported size used to surface as an
        # UnboundLocalError and an unsupported evaluator as a later
        # AttributeError on ``self.score``.
        if size not in self._SIZES:
            raise ValueError(
                f"Unsupported network size {size!r}; expected one of {self._SIZES}"
            )
        if evaluator not in self._EVALUATORS:
            raise ValueError(
                f"Unsupported evaluator {evaluator!r}; expected one of {self._EVALUATORS}"
            )

        self.idx = idx
        self.mode = mode
        self.evaluator = evaluator
        if size == 100:
            mat_data = data_utils.read_size_100_time_series_data(idx, mode)
        else:
            mat_data = data_utils.general_read_Dream4_data(idx, self.mode)
        self.data = pd.DataFrame(
            mat_data, columns=data_utils.DataName.get_column_names(self.mode, size)
        )
        self.ground_truth_vector = self._read_ground_truth({"size": size, "idx": idx})

        scorer_classes = {"k2": K2Score, "bic": BicScore, "bde": BDeuScore}
        self.score = scorer_classes[evaluator](self.data)

    @staticmethod
    def _build_model(edges):
        """Return ``BayesianNetwork(edges)``, or None if construction fails."""
        try:
            return BayesianNetwork(edges)
        except Exception:
            # pgmpy rejects structures such as cycles; networkx may wrap that
            # error in its own exception type, so catch broadly — but not
            # BaseException, so Ctrl-C still interrupts a long search.
            return None

    def evaluator_fun(self, vector):
        """Return the structure score of the network encoded by ``vector``."""
        edges = DataConversion.general_vector2edges(vector, self.mode)
        score = self._get_score(edges)
        return score

    def valid_fun(self, vector):
        """Return True if ``vector`` decodes to a constructible BayesianNetwork."""
        edges = DataConversion.general_vector2edges(vector, self.mode)
        return self._build_model(edges) is not None

    def _get_score(self, edges):
        """Score an edge list; return the sentinel ``1e12`` for invalid structures."""
        model = self._build_model(edges)
        if model is None:
            return 1e12
        return self.score.score(model)

    def _read_ground_truth(self, args):
        """Load the DREAM4 gold-standard network and return it as a vector.

        Args:
            args: dict with ``"size"`` (network size) and ``"idx"`` (network
                index), used to build the gold-standard TSV path.

        Returns:
            The ``G1..G{size}`` adjacency matrix (rows are sources, columns
            are targets) converted by ``DataConversion.matrix2vector``.
        """
        size = args["size"]
        file_name = f"DREAM4 in silico challenge/Size {size}/DREAM4 gold standards/insilico_size{size}_{args['idx']}_goldstandard.tsv"
        print(f"{Fore.GREEN}Reading Ground Truth{Fore.RESET} : {file_name}")
        df = pd.read_csv(
            file_name, sep="\t", header=None, names=["source", "target", "value"]
        )
        # All gene nodes, G1..G{size}.
        nodes = [f"G{i}" for i in range(1, size + 1)]
        # Start from an all-zero adjacency matrix.
        adj_matrix = pd.DataFrame(
            np.zeros((len(nodes), len(nodes)), dtype=int), index=nodes, columns=nodes
        )

        # Fill in every listed (source, target) pair with its value.
        for _, row in df.iterrows():
            adj_matrix.loc[row["source"], row["target"]] = row["value"]
        return DataConversion.matrix2vector(adj_matrix.values)


# ---- SOLVE TEST CASE WITH ARTIFICIAL BEE COLONY ALGORITHM
# Time-series data of DREAM4 network 1 labelled with the D3 column names.
# Loaded once at import time; D3evaluator scores candidate structures on it.
GLOBAL_DATA = pd.DataFrame(
    data=data_utils.read_Dream4_time_series_data(1),
    columns=data_utils.DataName.D3ColumnNames,
)


# (data, scorer) pair cached by D3evaluator so the BicScore for GLOBAL_DATA is
# built once instead of on every call; the data object is remembered so that
# rebinding GLOBAL_DATA invalidates the cached scorer.
_d3_scorer_cache = (None, None)


def D3evaluator(vector):
    """Score ``vector`` as a D3-mode structure with BIC on ``GLOBAL_DATA``.

    Args:
        vector: candidate solution vector, decoded with
            ``DataConversion.general_vector2edges(vector, "D3")``.

    Returns:
        The BIC score of the decoded network, or the sentinel ``1e12`` when
        the edges cannot form a ``BayesianNetwork``.
    """
    global _d3_scorer_cache
    edges = DataConversion.general_vector2edges(vector, "D3")
    try:
        model = BayesianNetwork(edges)
    except Exception:
        # Invalid structure (e.g. a cycle). Catch Exception rather than using
        # a bare except so KeyboardInterrupt/SystemExit still propagate.
        return 1e12
    cached_data, scorer = _d3_scorer_cache
    if scorer is None or cached_data is not GLOBAL_DATA:
        scorer = BicScore(GLOBAL_DATA)
        _d3_scorer_cache = (GLOBAL_DATA, scorer)
    return scorer.score(model)


def get_args():
    """Parse the command-line options for a bee-colony run.

    Returns:
        argparse.Namespace with ``bee_num``, ``max_itrs``, ``mode``,
        ``evaluator``, ``mutation``, ``net_idx``, ``init_alpha``,
        ``init_edges`` and ``task_name``.
    """
    # (flag, type, default, help) — help=None is argparse's own default.
    option_table = (
        ("--bee_num", int, 20, None),
        ("--max_itrs", int, 200, None),
        ("--mode", str, "D3", "D3,D1,Real"),
        ("--evaluator", str, "bic", "k2,bic,bde"),
        ("--mutation", str, "mutation", "mutation,cross_and_mutation,cross"),
        ("--net_idx", int, 3, "1,2,3,4,5"),
        ("--init_alpha", float, 0.5, "0.5,0.6,0.7,0.8,0.9"),
        ("--init_edges", str, "", "Path to the file containing initial edges"),
        ("--task_name", str, "", None),
    )
    cli = argparse.ArgumentParser()
    for flag, cast, default, help_text in option_table:
        cli.add_argument(flag, type=cast, default=default, help=help_text)
    return cli.parse_args()


def experiments_run(args, new_args=None):
    """Run one artificial-bee-colony structure search and post-process it.

    Args:
        args: namespace (as produced by ``get_args``) supplying the search
            settings read here: ``mode``, ``init_edges``, ``bee_num``,
            ``max_itrs``, ``mutation`` and ``init_alpha``.
        new_args: dict describing the target network. ``net_size``,
            ``net_idx``, ``mode`` and ``evaluator`` are read here, and the
            whole dict is forwarded to ``post_handler``. Omitting it is an
            error, reported below.

    Raises:
        KeyError: if ``new_args`` lacks a key this function needs.
    """
    # None instead of a ``{}`` default: a mutable default is shared across calls.
    new_args = {} if new_args is None else new_args
    required = ("net_size", "net_idx", "mode", "evaluator")
    missing = [key for key in required if key not in new_args]
    if missing:
        raise KeyError(f"new_args is missing required keys: {missing}")

    print(Fore.GREEN, args, Fore.RESET)
    if args.init_edges == "":
        edges = None
    else:
        with open(args.init_edges, "r", encoding="utf-8") as f:
            edges = json.load(f)

    # The character after the leading letter of the mode ("D1", "D3") gives
    # dim_num; the search vector has net_size**2 * dim_num entries.
    # NOTE(review): a mode such as "Real" (listed in get_args' help) makes
    # int() raise ValueError here.
    # NOTE(review): dim_num and BeeHive use args.mode while Evaluation uses
    # new_args["mode"]; callers must keep the two equal.
    dim_num = int(args.mode[1])
    ndim = int(new_args["net_size"] * new_args["net_size"] * dim_num)

    evaluation = Evaluation(
        idx=new_args["net_idx"],
        mode=new_args["mode"],
        evaluator=new_args["evaluator"],
        size=new_args["net_size"],
    )
    model = Hive.BeeHive(
        lower=[0.0] * ndim,
        upper=[1.0] * ndim,
        seed=0,
        evaluate=evaluation.evaluator_fun,
        numb_bees=args.bee_num,
        max_itrs=args.max_itrs,
        mode=args.mode,
        mutation=args.mutation,
        init_edges=edges,
        init_alpha=args.init_alpha,
    )
    # The best cost is also captured in train_info; the return value is unused.
    model.run()
    train_info = model.get_train_info()
    post_handler(train_info, new_args, evaluation.ground_truth_vector)


def post_handler(train_info, args, ground_truth_vector):
    """Render, evaluate and persist the result of one bee-colony run.

    Draws the best solution's DAG (full and "middle" views). For every
    sub-solution yielded by ``DataConversion.general_vector2solutions`` it
    stores the vector, its edge list and its evaluation against
    ``ground_truth_vector`` in ``train_info``. It then plots the convergence
    curves and saves ``train_info`` with ``data_utils.save_train_info``.

    Args:
        train_info: dict from the optimizer; must contain ``best_vector`` and
            ``convergence`` (with ``best`` and ``mean``). Mutated in place.
        args: run-configuration dict; ``bee_num``, ``max_itrs`` and ``mode``
            are read here, and the dict is passed to
            ``data_utils.DirUtils.get_output_path``.
        ground_truth_vector: gold-standard vector used for evaluation.
    """
    out_dir = data_utils.DirUtils.get_output_path(args)
    print(f"{Fore.BLUE}Output Path{Fore.RESET} : {out_dir}")
    # Every artifact shares the "<out_dir>/<bee_num>_<max_itrs>" prefix.
    prefix = f"{out_dir}/{args['bee_num']}_{args['max_itrs']}"
    best = train_info["best_vector"]

    dag_png = f"{prefix}_result_DAG.png"
    print(f"{Fore.BLUE} result DAG saving into {Fore.RESET} : {dag_png}")
    DG_Utils.general_show(best, save_path=dag_png, mode=args["mode"])

    mid_png = f"{prefix}_mid_DAG.png"
    print(f"{Fore.BLUE} middle DAG saving into {Fore.RESET} : {mid_png}")
    DG_Utils.general_mid_show(best, save_path=mid_png, mode=args["mode"])

    sub_solutions = DataConversion.general_vector2solutions(best, args["mode"])
    for name, sub in sub_solutions:
        train_info[f"{name}_result"] = sub
        train_info[f"{name}_edges"] = DataConversion.vector2edges(sub)
        train_info[f"{name}_eval_result"] = DAG_Utils.vector_evaluate(
            sub, ground_truth_vector
        )

    union_eval = train_info["union_eval_result"]
    print(f"{Fore.GREEN}Union eval{Fore.RESET} : {union_eval}")
    info_json = f"{prefix}_save_info.json"
    print(f"{Fore.BLUE} Saving best bee info into {Fore.RESET} : {info_json}")

    curve_png = f"{prefix}_loss_plot.png"
    print(f"{Fore.BLUE}Saving loss plot into {Fore.RESET} : {curve_png}")
    history = train_info["convergence"]
    loss_plot_simple(
        history["best"],
        history["mean"],
        save_path=curve_png,
        show=False,
    )
    # Saved last so the file includes the per-solution entries added above.
    data_utils.save_train_info(train_info, info_json)


def exp_with_file(train_data, gold_data, output_dir):
    """Placeholder for running an experiment on user-supplied data files.

    Not implemented yet: it only assembles the default run configuration and
    returns None. ``train_data``, ``gold_data`` and ``output_dir`` are
    currently ignored.
    """
    # Default configuration intended for this entry point.
    new_args = dict(
        bee_num=40,
        max_itrs=50,
        mode="D1",
        evaluator="bic",
        mutation="mutation",
        net_idx=1,
        init_alpha=0.5,
        init_edges="",
        task_name="no_init_edges",
        net_size=10,
    )
    # TODO: load train_data / gold_data and run the search with new_args.
    return None


def single_exp():
    """Run the default experiment on DREAM4 networks 1-5 in sequence.

    The config dict below is the single source of truth. ``experiments_run``
    reads the search settings (bee_num, max_itrs, mode, mutation, init_alpha,
    init_edges) from ``args``, but hands the dict to ``post_handler`` for
    output naming. Copying every key onto ``args`` keeps what is run and what
    is recorded in agreement.
    """
    args = get_args()

    base_config = {
        "bee_num": 40,
        "max_itrs": 50,
        "mode": "D1",
        "evaluator": "bic",
        "mutation": "mutation",
        "net_idx": 1,
        "init_alpha": 0.5,
        "init_edges": "",
        "task_name": "no_init_edges",
        "net_size": 10,
    }
    for net_idx in range(1, 6):
        # Fresh dict per run so no run can observe a later run's mutation.
        run_config = dict(base_config, net_idx=net_idx)
        for key, value in run_config.items():
            setattr(args, key, value)
        experiments_run(args, run_config)


# Debug helper: BIC score of a size-100 gold-standard network.
def gen_100_test(net_idx=1):
    """Print and return the BIC score of a size-100 DREAM4 gold-standard network.

    Columns A1..A100, B1..B100 and C1..C100 label the size-100 time-series
    data. For every gold-standard edge with value 1 (source i -> target j,
    indices taken after the first character of the name) the network gets
    ``Ai -> Cj`` and ``Bi -> Cj``. It also gets ``Ci -> Cj`` unless that edge
    makes ``BayesianNetwork`` construction fail.

    Args:
        net_idx: DREAM4 network index (default 1, the previous hard-coded value).

    Returns:
        The BIC score of the resulting network.
    """
    raw = data_utils.read_size_100_time_series_data(net_idx)
    print(raw.shape)
    columns = [f"{group}{i}" for group in "ABC" for i in range(1, 101)]
    data = pd.DataFrame(raw, columns=columns)
    bic_scorer = BicScore(data)
    gold_edges = data_utils.get_gold_standard(100, net_idx)

    edges = []
    for _, row in gold_edges.iterrows():
        if row["value"] != 1:
            continue
        src_idx = int(row["source"][1:])
        tar_idx = int(row["target"][1:])
        # A*/B* nodes only ever have outgoing edges, so these cannot form a cycle.
        edges.append((f"A{src_idx}", f"C{tar_idx}"))
        edges.append((f"B{src_idx}", f"C{tar_idx}"))
        # Tentatively add the C->C edge and drop it if the network becomes
        # invalid (e.g. a self-loop or cycle). networkx may wrap pgmpy's
        # ValueError in its own exception type, hence Exception, not ValueError.
        edges.append((f"C{src_idx}", f"C{tar_idx}"))
        try:
            BayesianNetwork(edges)
        except Exception:
            edges.pop()
    print(f"end with {len(edges)} edges")
    score = bic_scorer.score(BayesianNetwork(edges))
    print(f"{score:.2e}")
    return score


if __name__ == "__main__":
    # run_experiments()  # runs the batch of experiments
    # NOTE(review): run_experiments is not defined in the visible part of this file.
    single_exp()  # runs a single experiment (for debugging)

# ---- END
